6  Supplemental Code

6.1 High Dimensional Mediation Analysis Code

6.1.1 Early Integration of omics datasets

Code
#' Conducts High Dimensional Mediation Analysis with Early Integration
#'
#' Given exposure, outcome and multiple omics data, concatenates all omics
#' layers column-wise into one mediator data frame (early integration), runs
#' HIMA on it, and returns a tidy data frame of the features whose
#' BH-adjusted FDR is below `fdr.level`.
#'
#' @param exposure A numeric vector for the exposure variable.
#' @param outcome A numeric vector for the outcome variable.
#' @param omics_lst A named list of numeric matrices representing omics data.
#'   Layers are joined on their row names (sample IDs); the list names are
#'   used as the `omic_layer` labels.
#' @param covs A numeric matrix representing the covariates. Must not be
#'   `NULL`; it is used for both the exposure-mediator and mediator-outcome
#'   models.
#' @param Y.family Outcome family passed to `HIMA::hima()`
#'   (default `"binomial"`).
#' @param M.family Mediator family passed to `HIMA::hima()`
#'   (default `"gaussian"`).
#' @param fdr.level BH-FDR threshold a feature must fall below to be kept
#'   (default `0.05`).
#'
#' @return A tibble with one row per significant feature: `multiomic_mthd`,
#'   `mediation_mthd`, `ftr_name`, the remaining `HIMA::hima()` output
#'   columns (with `alpha*beta` renamed to `ie`), `TE (%)` (each feature's
#'   `% total effect` rescaled so the significant features sum to 100),
#'   `sig`, and the `omic_layer` / `omic_num` feature metadata.
#'
#' @import dplyr
#' @importFrom tibble tibble as_tibble column_to_rownames
#' @importFrom purrr map imap_dfr reduce
#' @importFrom stringr str_detect
#' @importFrom HIMA hima
hima_early_integration <- function(exposure,
                                   outcome,
                                   omics_lst,
                                   covs,
                                   Y.family = "binomial",
                                   M.family = "gaussian",
                                   fdr.level = 0.05) {
  # Give error if covs is NULL
  if (is.null(covs)) {
    stop("Currently, hima does not support analysis without covariates.
         Please provide covariates.")
  }

  # Convert each omics matrix to a tibble keyed by its row names
  omics_lst_df <- purrr::map(omics_lst, ~ as_tibble(.x, rownames = "name"))

  # Feature metadata: source layer and a numeric code for that layer.
  # "meth" is tested before the broader "met" pattern so methylation
  # features are not coded as metabolites.
  meta_df <- imap_dfr(omics_lst_df,
                      ~ tibble(omic_layer = .y, ftr_name = names(.x))) %>%
    filter(ftr_name != "name") %>%
    mutate(omic_num = case_when(str_detect(omic_layer, "meth") ~ 1,
                                str_detect(omic_layer, "transc") ~ 2,
                                str_detect(omic_layer, "miR") ~ 3,
                                str_detect(omic_layer, "pro") ~ 4,
                                str_detect(omic_layer, "met") ~ 5))

  # Early integration: one data frame holding every layer's features,
  # rows matched across layers by row name
  omics_df <- omics_lst_df %>%
    purrr::reduce(left_join, by = "name") %>%
    column_to_rownames("name")

  # Run hima on the concatenated mediators
  result_hima_early <- hima(X = exposure,
                            Y = outcome,
                            M = omics_df,
                            COV.XM = covs,
                            COV.MY = covs,
                            Y.family = Y.family,
                            M.family = M.family,
                            verbose = FALSE,
                            max.iter = 100000,
                            scale = FALSE) %>%
    as_tibble(rownames = "ftr_name")

  # Add method labels and move them to the front
  result_hima_early <- result_hima_early %>%
    dplyr::mutate(
      multiomic_mthd = "Early Integration",
      mediation_mthd = "HIMA") %>%
    dplyr::select(multiomic_mthd, mediation_mthd,
                  ftr_name,
                  everything())

  # Keep significant features and rescale % total effect to sum to 100.
  # `sig` is kept for a consistent output schema (always 1 after the filter).
  result_hima_early <- result_hima_early %>%
    filter(BH.FDR < fdr.level) %>%
    mutate(pte = 100 * `% total effect` / sum(`% total effect`),
           sig = if_else(BH.FDR < fdr.level, 1, 0)) %>%
    rename(ie = `alpha*beta`,
           `TE (%)` = pte)

  # Merge results with feature metadata
  result_hima_early %>%
    left_join(meta_df, by = "ftr_name")
}

6.1.2 Intermediate Integration of omics datasets

Code
#' Conducts High Dimensional Mediation Analysis with Intermediate Integration
#'
#' Estimates the indirect effect of each candidate mediating feature on the
#' exposure -> outcome relationship. Omics-layer membership is used as
#' external information in a penalized (xtune) outcome model:
#'
#' 0. `gamma`: exposure coefficient from a linear model of the outcome on the
#'    exposure, adjusted for `covs` when supplied.
#' 1. `alpha` / `alpha_se`: per-feature exposure -> mediator estimates from
#'    `epiomics::owas()`, adjusted for `covs`.
#' 2. Mediator -> outcome coefficients from `xtune::xtune()`, with a
#'    feature x layer 0/1 indicator matrix as external information `Z` and
#'    the exposure (plus covariates) as `U`.
#' 3. `beta` / `beta_se`: mean and standard deviation of the xtune feature
#'    coefficients over `n_boot` bootstrap resamples (`boot::boot()`).
#' 4. `ie`, `ind_effect_se`, `lcl`, `ucl`: indirect effect and confidence
#'    limits from `RMediation::medci(type = "dop")`.
#'
#' Only features whose confidence interval excludes zero are returned.
#'
#' @param exposure Numeric vector with the exposure variable.
#' @param outcome Numeric vector with the (continuous) outcome variable.
#' @param omics_lst Named list of omics matrices, joined on their row names
#'   (sample IDs); the list names label the omics layers.
#' @param covs Optional data frame or matrix of covariates (default `NULL`).
#' @param n_boot Number of bootstrap resamples used for the `beta` SEs.
#' @param Y.family Outcome family; only `"gaussian"` is currently supported.
#'
#' @return A tibble with one row per significant feature, including
#'   `ftr_name`, `omic_layer`, `omic_num`, `s1` (full-data xtune
#'   coefficient), `beta`, `beta_se`, `alpha`, `alpha_se`, `ie`,
#'   `ind_effect_se`, `lcl`, `ucl`, `gamma`, `TE (%)` (each feature's
#'   `ie / gamma` rescaled to sum to 100 across significant features) and
#'   `sig`.
#'
#' @examples
#' \dontrun{
#' hima_intermediate_integration(exposure = exposure_data,
#'                               outcome = outcome_data,
#'                               omics_lst = list(methylome = omics_1,
#'                                                transcriptome = omics_2),
#'                               covs = covariate_data,
#'                               n_boot = 100)
#' }
#'
#' @importFrom epiomics owas
#' @importFrom xtune xtune estimateVariance
#' @importFrom boot boot
#' @importFrom parallel detectCores
#' @importFrom RMediation medci
#' @importFrom dplyr bind_cols full_join inner_join left_join mutate filter
#' @importFrom dplyr select rename group_by ungroup if_else
#' @importFrom purrr map map2 reduce imap_dfr
#' @importFrom tibble tibble as_tibble column_to_rownames
#' @importFrom tidyr nest unnest replace_na
#' @importFrom janitor clean_names
#' @importFrom stringr str_detect
hima_intermediate_integration <- function(exposure,
                                          outcome,
                                          omics_lst,
                                          covs = NULL,
                                          n_boot,
                                          Y.family = "gaussian") {
  # xtune is fit with family = "linear" below, so only continuous outcomes
  # are supported; fail before doing any work
  if (Y.family != "gaussian") {
    stop("Only continuous outcomes currently supported")
  }

  ## Change omics elements to data frames keyed by row name
  omics_lst_df <- purrr::map(omics_lst, ~ as_tibble(.x, rownames = "name"))

  # Feature metadata: source layer and numeric layer code ("meth" is tested
  # before the broader "met" pattern)
  meta_df <- imap_dfr(omics_lst_df,
                      ~ tibble(omic_layer = .y, ftr_name = names(.x))) %>%
    filter(ftr_name != "name") %>%
    mutate(omic_num = case_when(str_detect(omic_layer, "meth") ~ 1,
                                str_detect(omic_layer, "transc") ~ 2,
                                str_detect(omic_layer, "miR") ~ 3,
                                str_detect(omic_layer, "pro") ~ 4,
                                str_detect(omic_layer, "met") ~ 5))

  ## Create data frame of omics data (all layers side by side)
  omics_df <- omics_lst_df %>%
    purrr::reduce(left_join, by = "name") %>%
    column_to_rownames("name")

  # Full analysis data: outcome, exposure, omics features, then covariates
  full_data <- tibble(outcome = outcome,
                      exposure = exposure) %>%
    bind_cols(omics_df)
  if (!is.null(covs)) {
    full_data <- full_data %>% bind_cols(covs)
  }
  # NULL when no covariates are supplied, so c(cov_names, "exposure")
  # reduces to "exposure"
  cov_names <- colnames(covs)

  # External information matrix: one row per feature, one 0/1 indicator
  # column per omics layer
  external_info <- purrr::map(omics_lst,
                              ~ data.frame(column = colnames(.x), val = 1)) %>%
    map2(names(omics_lst), ~ dplyr::rename(.x, !!.y := val)) %>%
    purrr::reduce(full_join, by = "column") %>%
    replace(is.na(.), 0) %>%
    column_to_rownames("column")

  # 0) gamma (x --> y) ----
  # Adjusted for the same covariates as the mediator and outcome models.
  # (Previously the gaussian branch ignored covs and the covariate-adjusted
  # branch was unreachable.)
  gamma_data <- full_data[, c("outcome", "exposure", cov_names)]
  gamma_est <- coef(lm(outcome ~ ., data = gamma_data))[["exposure"]]

  # 1) X --> M ----------
  x_m_reg <- epiomics::owas(df = full_data,
                            omics = rownames(external_info),
                            covars = cov_names,
                            var = "exposure",
                            var_exposure_or_outcome = "exposure") %>%
    dplyr::select(feature_name, estimate, se) %>%
    dplyr::rename(alpha = estimate,
                  alpha_se = se)

  # 2) X + M --> Y with layer membership as external information ----
  X <- as.matrix(full_data[, colnames(omics_df)])
  Y <- full_data$outcome
  Z <- as.matrix(external_info)
  U <- as.matrix(full_data[, c(cov_names, "exposure")])

  # Run xtune on all data (console output suppressed)
  invisible(
    capture.output(
      xtune.fit_all_data <- xtune(X = X, Y = Y, Z = Z, U = U,
                                  c = 1,
                                  family = "linear")
    )
  )

  # Extract feature estimates (drops the intercept and U coefficients)
  xtune_betas_all_data <- as_tibble(as.matrix(xtune.fit_all_data$beta.est),
                                    rownames = "feature_name") %>%
    left_join(meta_df, by = c("feature_name" = "ftr_name")) %>%
    dplyr::filter(feature_name %in% colnames(omics_df))

  # 3) Bootstrap SEs for the M --> Y coefficients -----
  # 3.1) Statistic for boot::boot(): refit xtune on one resample and return a
  # one-column matrix with one coefficient per row of external_info. xtune
  # occasionally errors on a resample, so each replicate is retried up to
  # `max_attempts` times. If every attempt fails, the replicate returns NA,
  # which is converted to 0 below like the empty-fit fallback. (Previously
  # the retry counter was only updated inside the error handler, so a
  # persistent error looped forever.)
  group_lasso_boot <- function(data, indices, external_info, covs = NULL,
                               max_attempts = 10) {
    n_ftr <- nrow(external_info)
    X <- as.matrix(data[indices, rownames(external_info)])
    Y <- data$outcome[indices]
    U <- as.matrix(data[indices, c(colnames(covs), "exposure")])

    fit_once <- function() {
      xtune.fit <- xtune(X = X, Y = Y, Z = as.matrix(external_info), U = U,
                         sigma.square = estimateVariance(X, Y),
                         c = 0,
                         family = "linear", message = FALSE)
      # Select feature betas, drop intercept / U coefficients
      xtune_betas <- as_tibble(as.matrix(xtune.fit$beta.est),
                               rownames = "feature_name") %>%
        dplyr::filter(feature_name %in% colnames(omics_df)) %>%
        dplyr::select(s1) %>%
        as.matrix()
      # Sometimes the fit returns an empty coefficient matrix; use all zeros
      if (identical(dim(xtune_betas), c(n_ftr, 1L))) {
        xtune_betas
      } else {
        as.matrix(rep(0, n_ftr))
      }
    }

    for (attempt in seq_len(max_attempts)) {
      betas <- tryCatch(fit_once(), error = function(e) {
        message("Encountered error: ", conditionMessage(e))
        NULL
      })
      if (!is.null(betas)) {
        return(betas)
      }
      if (attempt < max_attempts) {
        message("Retrying...")
      }
    }
    message("Maximum number of attempts reached. Exiting...")
    as.matrix(rep(NA_real_, n_ftr))
  }

  # 3.2) Run bootstrap; covs (possibly NULL) is forwarded to the statistic
  boot_out <- boot(data = full_data,
                   statistic = group_lasso_boot,
                   R = n_boot,
                   # strata = as.numeric(full_data$h_cohort),
                   ncpus = parallel::detectCores(),
                   parallel = "multicore",
                   external_info = external_info,
                   covs = covs)

  # Bootstrap mean and SD of each feature's coefficient
  glasso_boot_results <- tibble(
    feature_name = rownames(external_info),
    beta_bootstrap = colMeans(replace_na(boot_out$t, 0)),
    beta_se = apply(replace_na(boot_out$t, 0), 2, sd)) %>%
    left_join(meta_df, by = c("feature_name" = "ftr_name"))

  # 3.4) Join full-data xtune, bootstrap and X --> M results ----
  int_med_coefs <- dplyr::inner_join(xtune_betas_all_data,
                                     glasso_boot_results,
                                     by = c("feature_name",
                                            "omic_layer", "omic_num")) %>%
    dplyr::inner_join(x_m_reg, by = "feature_name")

  # Confidence intervals for the indirect effect -----
  # mu.x: a1 from reg m = a0 + a1*X
  # mu.y: b2 from reg y = b0 + b1*X + b2*M
  int_med_res <- int_med_coefs %>%
    group_by(feature_name) %>%
    nest() %>%
    mutate(res = purrr::map(data,
                            ~ RMediation::medci(mu.x = .x$alpha,
                                                se.x = .x$alpha_se,
                                                mu.y = .x$beta_bootstrap,
                                                se.y = .x$beta_se,
                                                type = "dop") %>%
                              unlist() %>% t() %>% as_tibble())) %>%
    unnest(c(res, data)) %>%
    ungroup() %>%
    janitor::clean_names()

  # Rename CI output, keep features whose CI excludes zero, and rescale the
  # proportion of total effect to sum to 100
  intermediate_int_res <- int_med_res %>%
    rename(indirect = "estimate",
           ind_effect_se = "se",
           lcl = x95_percent_ci1,
           ucl = x95_percent_ci2) %>%
    mutate(gamma = gamma_est,
           pte = indirect / gamma,
           sig = if_else(lcl > 0 | ucl < 0, 1, 0)) %>%
    filter(sig == 1) %>%
    mutate(pte = 100 * pte / sum(pte)) %>%
    dplyr::rename(ftr_name = feature_name,
                  ie = indirect,
                  beta = beta_bootstrap,
                  `TE (%)` = pte)

  intermediate_int_res
}

6.1.3 Late Integration of omics datasets

Code
#' Conducts High Dimensional Mediation Analysis with Late Integration
#'
#' Runs HIMA (High-dimensional Mediation Analysis) separately on each omics
#' layer, using the same exposure, outcome and covariates each time
#' (late integration). The per-layer results are stacked and filtered to
#' features with BH-FDR below `fdr.level`. Each feature's `% total effect` is
#' rescaled so the significant features sum to 100, and a numeric layer code
#' (`omic_num`) is added.
#'
#' @param exposure A numeric vector containing exposure measurements.
#' @param outcome A numeric vector containing outcome measurements.
#' @param omics_lst A named list of omics layers (mediator matrices); the list
#'   names become the `omic_layer` labels.
#' @param covs Covariates passed to `HIMA::hima()` as `COV.XM` (and hence, by
#'   HIMA's default, as `COV.MY`). Must not be `NULL`.
#' @param Y.family Outcome family passed to `HIMA::hima()`.
#' @param M.family Mediator family passed to `HIMA::hima()`
#'   (default `"gaussian"`).
#' @param fdr.level BH-FDR threshold a feature must fall below to be kept
#'   (default `0.05`).
#' @return A tibble with `multiomic_mthd`, `mediation_mthd`, `omic_layer`,
#'   `ftr_name`, the remaining HIMA output columns (with `alpha*beta` renamed
#'   to `ie`), `TE (%)`, `sig` and `omic_num`.
#' @importFrom dplyr mutate select bind_rows filter rename if_else case_when
#' @importFrom tibble as_tibble
#' @importFrom HIMA hima
#' @importFrom purrr map
#' @importFrom stringr str_detect
#'
#' @export
hima_late_integration <- function(exposure,
                                  outcome,
                                  omics_lst,
                                  covs,
                                  Y.family,
                                  M.family = "gaussian",
                                  fdr.level = 0.05) {

  # Give error if covs is NULL
  if (is.null(covs)) {
    stop("Currently, hima does not support analysis without covariates.
         Please provide covariates.")
  }

  # Run HIMA separately on each omics layer; map() keeps the list names,
  # which become the omic_layer labels when the results are stacked
  result_hima_late <- purrr::map(omics_lst, function(omics_layer) {
    hima(X = exposure,
         Y = outcome,
         M = omics_layer,
         COV.XM = covs,
         Y.family = Y.family,
         M.family = M.family,
         max.iter = 100000,
         scale = FALSE) %>%
      as_tibble(rownames = "ftr_name")
  })

  # Stack the per-layer results and add method labels
  result_hima_late_df <- bind_rows(result_hima_late, .id = "omic_layer") %>%
    dplyr::mutate(
      multiomic_mthd = "Late Integration",
      mediation_mthd = "HIMA") %>%
    dplyr::select(multiomic_mthd, mediation_mthd,
                  omic_layer, ftr_name,
                  everything())

  # Keep significant features, rescale % total effect to sum to 100, and add
  # the numeric layer code used by the other integration methods ("meth" is
  # tested before the broader "met" pattern). `sig` is always 1 after the
  # filter and is kept for a consistent output schema.
  result_hima_late_df %>%
    filter(BH.FDR < fdr.level) %>%
    mutate(pte = 100 * `% total effect` / sum(`% total effect`),
           sig = if_else(BH.FDR < fdr.level, 1, 0)) %>%
    rename(ie = `alpha*beta`,
           `TE (%)` = pte) %>%
    mutate(omic_num = case_when(str_detect(omic_layer, "meth") ~ 1,
                                str_detect(omic_layer, "transc") ~ 2,
                                str_detect(omic_layer, "miR") ~ 3,
                                str_detect(omic_layer, "pro") ~ 4,
                                str_detect(omic_layer, "met") ~ 5))
}

6.1.4 Plot HIMA

Code
#' Plot of High Dimensional Mediation Analysis
#'
#' Draws a faceted bar chart of HIMA results:
#' - one facet row each for the Alpha, Beta and TE (%) statistics;
#' - one facet column per omics layer;
#' - one bar per feature, filled by omics layer.
#'
#' @param result_hima A tidy data frame of HIMA results containing at least
#'   `ftr_name`, `omic_layer`, `alpha`, `beta` and `TE (%)`.
#'
#' @return A ggplot object.
#'
#' @import dplyr
#' @importFrom ggplot2 ggplot
#'
plot_hima <- function(result_hima) {
  # One row per feature x statistic; the factor fixes the facet-row order
  stats_long <- result_hima %>%
    rename(Alpha = alpha, Beta = beta) %>%
    pivot_longer(cols = c("Alpha", "Beta", "TE (%)"), names_to = "name") %>%
    mutate(name = factor(name, levels = c("Alpha", "Beta", "TE (%)")))

  # Blank out the facet-column strips, rotate the feature labels, and place
  # an untitled legend below the plot
  plot_theme <- theme(
    strip.background = element_blank(),
    strip.text.x = element_blank(),
    axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5),
    legend.title = element_blank(),
    legend.position = "bottom",
    legend.justification = c(1, 0)
  )

  # Features keep their order of first appearance on the x axis
  ggplot(stats_long,
         aes(x = fct_inorder(ftr_name), y = value, fill = omic_layer)) +
    geom_bar(stat = "identity") +
    facet_grid(name ~ omic_layer, scales = "free", space = "free_x") +
    scale_fill_brewer(type = "qual", palette = 2) +
    geom_hline(yintercept = 0, linetype = 1, color = "grey50") +
    labs(x = NULL, y = NULL) +
    plot_theme
}

6.2 Mediation with Latent Factors Analysis Code

6.2.1 Early Integration of omics datasets

Code
#' Mediation Analysis with Latent Factors: Early Integration
#'
#' Concatenates all omics layers into one samples x features data frame and
#' runs principal component analysis (PCA) on it as a joint dimensionality
#' reduction step. It keeps the leading principal components (PCs) that
#' together explain more than `var.threshold` (default 80%) of the variance,
#' and then runs HIMA with those standardized PCs as the mediators.
#'
#' @param exposure A numeric vector for the exposure variable.
#' @param outcome A numeric vector for the outcome variable.
#' @param omics_lst A named list of numeric matrices representing omics data,
#'   joined on their row names (sample IDs).
#' @param covs A numeric matrix representing the covariates. Must not be NULL.
#' @param Y.family Outcome family passed to `HIMA::hima()`
#'   (default "gaussian").
#' @param M.family Mediator family passed to `HIMA::hima()`
#'   (default "gaussian").
#' @param fdr.level BH-FDR threshold a PC must fall below to be kept
#'   (default 0.05).
#' @param var.threshold Cumulative proportion of variance that the retained
#'   PCs must exceed (default 0.8).
#'
#' @return A list with three elements:
#'   `result_hima_pca_early_sig`, the HIMA results for the significant PCs
#'   (with `TE (%)` rescaled to sum to 100 and a `lf_ordered` factor for
#'   plotting); `result_ftr_cor_sig_pcs_early`, the correlation of every
#'   feature with each significant PC plus the feature's `omic_layer` and
#'   `omic_num`; and `integration_type`, the string " Early".
#'
#' @import dplyr
#' @importFrom tibble tibble as_tibble column_to_rownames
#' @importFrom purrr map imap_dfr reduce
#' @importFrom stats prcomp cor var
#' @importFrom HIMA hima
#' @importFrom stringr str_replace str_detect
#'
med_lf_early <- function(exposure,
                         outcome,
                         omics_lst,
                         covs,
                         Y.family = "gaussian",
                         M.family = "gaussian",
                         fdr.level = 0.05,
                         var.threshold = 0.8) {

  # Give error if covs is NULL
  if (is.null(covs)) {
    stop("Currently, hima does not support analysis without covariates.
         Please provide covariates.")
  }

  # Convert each omics matrix to a tibble keyed by its row names
  omics_lst_df <- purrr::map(omics_lst, ~ as_tibble(.x, rownames = "name"))

  # Meta data: source layer and numeric layer code for each feature
  meta_df <- imap_dfr(omics_lst_df,
                      ~ tibble(omic_layer = .y, ftr_name = names(.x))) %>%
    filter(ftr_name != "name") %>%
    mutate(omic_num = case_when(str_detect(omic_layer, "meth") ~ 1,
                                str_detect(omic_layer, "transc") ~ 2,
                                str_detect(omic_layer, "miR") ~ 3,
                                str_detect(omic_layer, "pro") ~ 4,
                                str_detect(omic_layer, "met") ~ 5))

  # Combine the omics layers into one samples x features data frame
  # (this was previously referenced without ever being created)
  omics_df <- omics_lst_df %>%
    purrr::reduce(left_join, by = "name") %>%
    column_to_rownames("name")

  # i. Obtain PCs ----
  omics_df_pca <- prcomp(omics_df, center = TRUE, scale. = TRUE)

  # Cumulative proportion of variance explained by the PCs
  vars <- apply(omics_df_pca$x, 2, var)
  cum_props <- cumsum(vars / sum(vars))

  # Smallest number of leading PCs explaining > var.threshold of the
  # variance; fall back to all PCs if the threshold is never exceeded
  n_pcs <- which(cum_props > var.threshold)[1]
  if (is.na(n_pcs)) {
    n_pcs <- length(cum_props)
  }

  # Standardized leading PCs are the latent mediators; drop = FALSE keeps a
  # named matrix even when only one PC is selected
  PCs <- scale(omics_df_pca$x[, seq_len(n_pcs), drop = FALSE])

  # ii. Perform HIMA with PCs as mediators ----
  result_hima_comb_pc <- hima(X = exposure,
                              Y = outcome,
                              M = PCs,
                              COV.XM = covs,
                              COV.MY = covs,
                              Y.family = Y.family,
                              M.family = M.family,
                              scale = FALSE)

  # Change to tibble and keep significant PCs
  result_hima_pca_early <- as_tibble(result_hima_comb_pc,
                                     rownames = "lf_num") %>%
    filter(BH.FDR < fdr.level)

  # Create signed and scaled %TE variables and plotting labels.
  # round() wraps the whole ratio: previously `%>%` (which binds tighter than
  # `/`) rounded only the denominator.
  result_hima_pca_early_sig <- result_hima_pca_early %>%
    mutate(
      te_direction = if_else(beta < 0,
                             -1 * `% total effect`,
                             `% total effect`),
      `% Total Effect scaled` =
        round(100 * `% total effect` / sum(`% total effect`), 1),
      multiomic_mthd = "Early") %>%
    mutate(lf_named = str_replace(lf_num, "PC", "Joint Comp. "),
           lf_ordered = forcats::fct_reorder(lf_named, te_direction)) %>%
    rename(Alpha = alpha,
           Beta = beta,
           `TE (%)` = `% Total Effect scaled`) %>%
    mutate(omic_num = case_when(str_detect(lf_num, "meth") ~ 1,
                                str_detect(lf_num, "transc") ~ 2,
                                str_detect(lf_num, "miR") ~ 3,
                                str_detect(lf_num, "pro") ~ 4,
                                str_detect(lf_num, "met") ~ 5,
                                TRUE ~ 0))

  # iii. Correlation of features vs PCs --------------
  var.cor <- cor(omics_df, omics_df_pca$x)

  # Keep only significant PCs (drop = FALSE: a single PC stays a matrix)
  ftr_cor_sig_pcs <- var.cor[, colnames(var.cor) %in%
                               result_hima_pca_early_sig$lf_num,
                             drop = FALSE]

  ftr_cor_sig_pcs_df <- ftr_cor_sig_pcs %>%
    as_tibble(rownames = "feature") %>%
    left_join(meta_df, by = c("feature" = "ftr_name"))

  list(result_hima_pca_early_sig = result_hima_pca_early_sig,
       result_ftr_cor_sig_pcs_early = ftr_cor_sig_pcs_df,
       integration_type = " Early")
}

6.2.2 Intermediate Integration of omics datasets

Code
#' Mediation Analysis with Latent Factors: Intermediate Integration
#'
#' Performs a joint dimensionality reduction of the omics layers with Joint
#' and Individual Variance Explained (JIVE). It then runs HIMA with the
#' standardized joint and individual JIVE factors as the mediators.
#'
#' @param exposure A numeric vector for the exposure variable.
#' @param outcome A numeric vector for the outcome variable.
#' @param omics_lst A named list of numeric matrices representing omics data
#'   (joined on row names / sample IDs). Each matrix is transposed before it
#'   is passed to JIVE.
#' @param covs A numeric matrix representing the covariates. Must not be NULL.
#' @param jive.rankJ Number of joint factors for JIVE to estimate. If NULL
#'   (or if `jive.rankA` is NULL), JIVE estimates both ranks itself
#'   (`method = "perm"`).
#' @param jive.rankA Number of individual factors for JIVE to estimate, one
#'   per omics layer, i.e. a numeric vector of length `length(omics_lst)`.
#'   If NULL (or if `jive.rankJ` is NULL), JIVE estimates both ranks.
#' @param Y.family Outcome family passed to `HIMA::hima()`
#'   (default "gaussian").
#' @param M.family Mediator family passed to `HIMA::hima()`
#'   (default "gaussian").
#' @param fdr.level BH-FDR threshold a factor must fall below to be kept
#'   (default 0.05).
#'
#' @return A list with three elements:
#'   `result_hima_jive_sig`, the HIMA results for the significant JIVE
#'   factors (with `TE (%)` rescaled to sum to 100 and a `lf_ordered` factor
#'   for plotting); `result_ftr_cor_sig_pcs_jive`, the correlation of every
#'   feature with each significant factor plus the feature's `omic_layer` and
#'   `omic_num`; and `integration`, the string "Intermediate".
#'
#' @import dplyr
#' @importFrom tibble tibble as_tibble column_to_rownames rownames_to_column
#' @importFrom purrr map imap_dfr reduce
#' @importFrom r.jive jive
#' @importFrom stats cor
#' @importFrom HIMA hima
#' @importFrom stringr str_replace str_detect str_split_fixed fixed
#' @importFrom tools toTitleCase
#'
med_lf_intermediate <- function(exposure,
                                outcome,
                                omics_lst,
                                covs,
                                jive.rankJ = NULL,
                                jive.rankA = NULL,
                                Y.family = "gaussian",
                                M.family = "gaussian",
                                fdr.level = 0.05) {

  # Give error if covs is NULL
  if (is.null(covs)) {
    stop("Currently, hima does not support analysis without covariates.
         Please provide covariates.")
  }

  # Convert each omics matrix to a tibble keyed by its row names
  omics_lst_df <- purrr::map(omics_lst, ~ as_tibble(.x, rownames = "name"))

  # Meta data: source layer and numeric layer code for each feature
  meta_df <- imap_dfr(omics_lst_df,
                      ~ tibble(omic_layer = .y, ftr_name = names(.x))) %>%
    filter(ftr_name != "name") %>%
    mutate(omic_num = case_when(str_detect(omic_layer, "meth") ~ 1,
                                str_detect(omic_layer, "transc") ~ 2,
                                str_detect(omic_layer, "miR") ~ 3,
                                str_detect(omic_layer, "pro") ~ 4,
                                str_detect(omic_layer, "met") ~ 5))

  # Samples x features data frame of all layers, used for the feature vs
  # factor correlations (previously referenced without being created)
  omics_df <- omics_lst_df %>%
    purrr::reduce(left_join, by = "name") %>%
    column_to_rownames("name")

  # Transpose omics matrices for JIVE and keep the layer names
  omics_t <- lapply(omics_lst, t)
  names(omics_t) <- names(omics_lst)

  # Run JIVE -----
  estimate_ranks <- is.null(jive.rankJ) || is.null(jive.rankA)
  jive.message <- if (estimate_ranks) {
    "Step A) Starting JIVE analysis..."
  } else {
    "Step A) Starting JIVE analysis with ranks given..."
  }
  message(paste0(jive.message, "    (",
                 format(Sys.time(), "%H:%M:%S"),
                 ")\n"))

  if (estimate_ranks) {
    # Let JIVE choose the ranks; method = "given" needs both ranks supplied
    result_jive2 <- jive(data = omics_t,
                         method = "perm",
                         conv = 1e-04,
                         maxiter = 100,
                         showProgress = FALSE)
  } else {
    result_jive2 <- jive(data = omics_t,
                         rankJ = jive.rankJ,
                         rankA = jive.rankA,
                         method = "given",
                         conv = 1e-04,
                         maxiter = 100,
                         showProgress = FALSE)
  }

  # Extract joint and individual factor scores from a JIVE result.
  # Returns a data frame with one row per column of result$data[[1]] and one
  # column per factor, named "Joint_<k>" or "<layer>_<k>".
  get_PCs_jive <- function(result) {
    # Number of joint factors
    n_joint <- result$rankJ
    # Number of individual factors (one entry per data matrix)
    n_indiv <- result$rankA
    # Number of data matrices in result
    l <- length(result$data)
    # Total number of factors to compute
    nPCs <- n_joint + sum(n_indiv)
    # Matrix to hold factor scores (factors in rows)
    PCs <- matrix(nrow = nPCs, ncol = dim(result$data[[1]])[2])
    # Vector to hold factor names
    PC_names <- rep("", nPCs)
    # Joint structure: SVD of the stacked joint matrices
    if (n_joint > 0) {
      SVD <- svd(do.call(rbind, result$joint), nu = n_joint, nv = n_joint)
      PCs[1:n_joint, ] <- diag(SVD$d)[1:n_joint, 1:n_joint] %*%
        t(SVD$v[, 1:n_joint])
      # paste() inserts a second space ("Joint  1"), replaced by "_" below
      PC_names[1:n_joint] <- paste("Joint ", 1:n_joint)
    }
    # Individual structure: SVD of each layer's individual matrix
    for (i in 1:l) {
      if (n_indiv[i] > 0) {
        SVD <- svd(result$individual[[i]], nu = n_indiv[i],
                   nv = n_indiv[i])
        # Rows of PCs belonging to this data matrix
        indices <- (n_joint + sum(n_indiv[0:(i - 1)]) + 1):(n_joint +
                                                             sum(n_indiv[0:i]))
        PCs[indices, ] <- diag(SVD$d)[1:n_indiv[i], 1:n_indiv[i]] %*%
          t(SVD$v[, 1:n_indiv[i]])
        PC_names[indices] <- paste0(names(result$data)[i],
                                    "_", 1:n_indiv[i])
      }
    }
    # Name factors and return samples x factors
    rownames(PCs) <- PC_names %>%
      str_replace("  ", "_")
    as.data.frame(t(PCs))
  }

  # Standardize each factor before using it as a mediator
  factors_jive <- get_PCs_jive(result_jive2) %>%
    dplyr::mutate(across(everything(), ~ as.vector(scale(.))))

  # Run mediation analysis ------
  message(paste0("Step B) Starting hima on JIVE factors...    (",
                 format(Sys.time(), "%H:%M:%S"),
                 ")\n"))

  # Use the caller's families (previously hard-coded to "gaussian")
  result_hima_jive <- hima(X = exposure,
                           Y = outcome,
                           M = factors_jive,
                           COV.MY = covs,
                           COV.XM = covs,
                           Y.family = Y.family,
                           M.family = M.family,
                           verbose = FALSE,
                           scale = FALSE)

  # Label factors by type (joint vs. the individual layer they came from)
  result_hima_jive_1 <- result_hima_jive %>%
    rownames_to_column("lf_num") %>%
    mutate(multiomic_mthd = "Intermediate",
           ind_joint = str_split_fixed(lf_num, fixed("_"), 2)[, 1],
           ind_joint_num = case_when(ind_joint == "Joint" ~ 1,
                                     ind_joint == "methylome" ~ 2,
                                     ind_joint == "transcriptome" ~ 3)) %>%
    dplyr::select(multiomic_mthd, everything())

  # Keep significant factors and create signed / scaled %TE variables
  result_hima_jive_sig <- result_hima_jive_1 %>%
    filter(BH.FDR < fdr.level) %>%
    mutate(`% Total Effect scaled` =
             round(100 * `% total effect` / sum(`% total effect`), 1),
           te_direction = if_else(beta < 0,
                                  -1 * `% total effect`,
                                  `% total effect`)) %>%
    mutate(lf_named = str_replace(tools::toTitleCase(lf_num),
                                  "_", " Comp. "),
           lf_ordered = forcats::fct_reorder(lf_named, te_direction)) %>%
    rename(Alpha = alpha,
           Beta = beta,
           `TE (%)` = `% Total Effect scaled`) %>%
    mutate(omic_num = case_when(str_detect(lf_num, "meth") ~ 1,
                                str_detect(lf_num, "transc") ~ 2,
                                str_detect(lf_num, "miR") ~ 3,
                                str_detect(lf_num, "pro") ~ 4,
                                str_detect(lf_num, "met") ~ 5,
                                TRUE ~ 0))

  # Correlation of features vs JIVE factors --------------
  var.cor <- cor(omics_df, factors_jive)

  # Keep only significant factors (drop = FALSE: one factor stays a matrix)
  ftr_cor_sig_pcs_jive <- var.cor[, colnames(var.cor) %in%
                                    result_hima_jive_sig$lf_num,
                                  drop = FALSE]

  ftr_cor_sig_pcs_jive_df <- ftr_cor_sig_pcs_jive %>%
    as_tibble(rownames = "feature") %>%
    left_join(meta_df, by = c("feature" = "ftr_name"))

  list(result_hima_jive_sig = result_hima_jive_sig,
       result_ftr_cor_sig_pcs_jive = ftr_cor_sig_pcs_jive_df,
       integration = "Intermediate")
}

6.2.3 Late Integration of omics datasets

Code
#' Late-integration mediation analysis with per-layer principal components
#'
#' Conducts principal component analysis (PCA) as a dimensionality reduction
#' step on each omics layer separately and keeps, for each layer, the leading
#' principal components (PCs) that together explain more than 80% of that
#' layer's variance. The retained PCs of all layers are then used as
#' candidate mediators in a HIMA mediation analysis.
#'
#' @param exposure A numeric vector for the exposure variable.
#' @param outcome A numeric vector for the outcome variable.
#' @param omics_lst A named list of numeric matrices representing omics data,
#'   one element per omics layer. Each matrix is passed to `prcomp()`, so its
#'   columns are treated as features; the list names prefix the PC names
#'   (e.g. `"<layer>_PC1"`).
#' @param covs A numeric matrix representing the covariates. Used as both
#'   `COV.XM` and `COV.MY` in `HIMA::hima()`; must not be `NULL`.
#' @param Y.family Outcome family passed to `HIMA::hima()`.
#' @param M.family Mediator family passed to `HIMA::hima()`.
#' @param fdr.level Threshold on HIMA's `BH.FDR`; PCs below it are reported
#'   as significant mediators.
#'
#' @return A list with three elements:
#'   * `result_hima_late_sig`: tidy data frame of the significant PCs with
#'     their alpha, beta and percentage of the total effect rescaled to sum
#'     to 100 over the significant PCs.
#'   * `result_ftr_cor_sig_pcs_late`: tidy data frame of the correlation of
#'     each feature with the significant PCs, joined with feature metadata.
#'   * `intergration_type`: the string `"Late"` (element name kept as-is for
#'     backward compatibility with existing callers).
#'
#' @import dplyr
#' @importFrom tidyr as_tibble
#' @importFrom dplyr left_join
#' @importFrom purrr imap imap_dfr map map_df reduce
#' @importFrom stats prcomp cor var
#' @importFrom HIMA hima
#' @importFrom stringr str_detect str_replace str_split_fixed
#' @importFrom stringr str_to_sentence str_c
#' @importFrom forcats fct_reorder
med_lf_late <- function(exposure,
                        outcome,
                        omics_lst,
                        covs,
                        Y.family = "gaussian",
                        M.family = "gaussian",
                        fdr.level = 0.05) {
  # HIMA is fitted with covariates in both the X->M and M->Y models
  if (is.null(covs)) {
    stop("Currently, hima does not support analysis without covariates.
         Please provide covariates.")
  }

  # Feature metadata: omics layer (and layer number) of every feature ----
  omics_lst_df <- purrr::map(omics_lst, ~as_tibble(.x, rownames = "name"))
  meta_df <- purrr::imap_dfr(
    omics_lst_df,
    ~tibble(omic_layer = .y, ftr_name = names(.x))
  ) %>%
    filter(ftr_name != "name") %>%
    mutate(omic_num = case_when(str_detect(omic_layer, "meth") ~ 1,
                                str_detect(omic_layer, "transc") ~ 2,
                                str_detect(omic_layer, "miR") ~ 3,
                                str_detect(omic_layer, "pro") ~ 4,
                                str_detect(omic_layer, "met") ~ 5))

  # PCA on a single omics layer. Returns the scaled scores of the PCs needed
  # to explain >80% of the variance (`scores`), the scaled scores of all PCs
  # (`scores_full`) and the number of retained PCs (`n_pcs_80_pct`).
  run_pca <- function(mat, omic_name) {
    pca <- prcomp(mat, scale. = TRUE)
    vars <- apply(pca$x, 2, var)
    cum_props <- cumsum(vars / sum(vars))
    # Smallest number of PCs whose cumulative share of variance exceeds 80%
    n_80_pct <- unname(which(cum_props > 0.8)[1])
    # drop = FALSE keeps a one-column matrix (with its "PC1" name) when a
    # single PC is enough; a plain `[, 1]` would collapse it to a vector.
    PCs <- pca$x[, seq_len(n_80_pct), drop = FALSE] %>%
      scale() %>%
      as.data.frame()
    colnames(PCs) <- paste0(omic_name, "_", colnames(PCs))
    PCs_full <- pca$x %>% scale() %>% as.data.frame()
    colnames(PCs_full) <- paste0(omic_name, "_", colnames(PCs_full))
    list(scores = PCs, scores_full = PCs_full, n_pcs_80_pct = n_80_pct)
  }

  # i. PCA once per omics layer (named like omics_lst) ----
  pca_lst <- purrr::imap(omics_lst, run_pca)

  # Retained PC scores of all layers, side by side
  scores_df <- purrr::map(pca_lst, "scores") %>%
    purrr::reduce(cbind) %>%
    as.data.frame()

  # ii. HIMA with the retained PCs as candidate mediators ----
  result_hima_late_integration <- hima(X = exposure,
                                       Y = outcome,
                                       M = scores_df,
                                       COV.MY = covs,
                                       COV.XM = covs,
                                       Y.family = Y.family,
                                       M.family = M.family,
                                       scale = FALSE)

  # Keep PCs significant at `fdr.level`; rescale %TE to sum to 100 over them
  result_hima_late_sig <- result_hima_late_integration %>%
    as_tibble(rownames = "lf_num") %>%
    filter(BH.FDR < fdr.level) %>%
    mutate(
      te_direction = if_else(beta < 0,
                             -1 * `% total effect`,
                             `% total effect`),
      `% Total Effect scaled` =
        100 * `% total effect` / sum(`% total effect`))

  # Labels for plotting: "<layer>_PCk" -> layer "Layer", label "Layer Comp. k"
  result_hima_late_sig <- result_hima_late_sig %>%
    mutate(lf_numeric = str_split_fixed(lf_num, fixed("_"), 2)[, 2],
           omic_layer = str_split_fixed(lf_num, fixed("_"), 2)[, 1] %>%
             str_to_sentence(),
           multiomic_mthd = "Late") %>%
    mutate(omic_layer = str_replace(omic_layer, "Mirna", "miRNA"),
           omic_pc = str_c(omic_layer, " ", lf_numeric) %>%
             str_replace("PC", "Comp. ")) %>%
    mutate(lf_ordered = forcats::fct_reorder(omic_pc, te_direction)) %>%
    dplyr::select(multiomic_mthd, omic_pc, omic_layer, lf_num, lf_ordered,
                  alpha, beta, `% Total Effect scaled`) %>%
    rename(Alpha = alpha,
           Beta = beta,
           `TE (%)` = `% Total Effect scaled`) %>%
    mutate(omic_num = case_when(str_detect(lf_num, "meth") ~ 1,
                                str_detect(lf_num, "transc") ~ 2,
                                str_detect(lf_num, "miR") ~ 3,
                                str_detect(lf_num, "pro") ~ 4,
                                str_detect(lf_num, "met") ~ 5))

  # iii. Correlation of each feature with every PC of its own layer ----
  # Layers are stacked row-wise; columns of other layers' PCs are NA.
  var.cor <- purrr::map_df(names(omics_lst), function(type) {
    cor(omics_lst[[type]], pca_lst[[type]]$scores_full) %>%
      as.data.frame()
  })

  # Keep only the significant PCs
  ftr_cor_sig_pcs_late <- var.cor %>%
    dplyr::select(all_of(result_hima_late_sig$lf_num)) %>%
    as_tibble(rownames = "feature") %>%
    left_join(meta_df, by = c("feature" = "ftr_name"))

  # NOTE: "intergration_type" spelling is kept for existing callers.
  list(result_hima_late_sig = result_hima_late_sig,
       result_ftr_cor_sig_pcs_late = ftr_cor_sig_pcs_late,
       intergration_type = "Late")
}

6.2.4 Plot Mediation

Code
#' Plot result of mediation analysis with latent factors
#'
#' Builds a two-panel figure from the output of a latent-factor mediation
#' function. Panel a) shows bars of the alpha, beta and scaled total-effect
#' estimates of each significant latent factor. Panel b) shows a heatmap of
#' the correlation of the most strongly correlated features with those
#' latent factors.
#'
#' @param med_lf_list A list whose first element is the tidy data frame of
#'   significant latent factors (uses columns `lf_num`, `lf_ordered`,
#'   `omic_num`, `Alpha`, `Beta`, `TE (%)`), whose second element is the
#'   feature correlation data frame (columns `feature`, `omic_layer`,
#'   `omic_num`, plus one column per significant latent factor), and whose
#'   third element is the integration type string.
#' @param n_top Number of features with the largest absolute correlation kept
#'   per latent factor in panel b). Defaults to 10.
#'
#' @return A cowplot figure combining both panels.
#'
#' @import dplyr
#' @importFrom dplyr left_join
#' @importFrom ggplot2 ggplot
#' @importFrom tidyr pivot_longer
#' @importFrom utils tail
plot_med_lf <- function(med_lf_list, n_top = 10) {

  # Panel A: bar graph of mediation effects of latent factors ----
  med_long <- med_lf_list[[1]] %>%
    pivot_longer(cols = c(Alpha, Beta, `TE (%)`))

  panel_a <- ggplot(med_long, aes(x = lf_ordered, y = value)) +
    geom_bar(stat = "identity", fill = "grey50") +
    geom_hline(yintercept = 0) +
    facet_grid(name ~ omic_num, scales = "free", space = "free_x",
               switch = "y") +
    ggh4x::facetted_pos_scales(
      y = list(
        name == "Alpha" ~ scale_y_continuous(
          limits = c(-.45, .45), breaks = c(-0.4, 0, .4)),
        name == "Beta" ~ scale_y_continuous(
          limits = c(-.45, .45), breaks = c(-0.4, 0, .4)),
        name == "TE (%)" ~ scale_y_continuous(
          limits = c(-1, 55), n.breaks = 4))) +
    theme(axis.title = element_blank(),
          strip.placement = "outside",
          strip.text.x = element_blank(),
          strip.text.y = element_text(angle = 0, size = 9),
          axis.text.x = element_blank(),
          strip.background = element_blank(),
          axis.text.y = element_text(size = 8))

  # Panel B: heatmap of correlation of features vs latent factors ----

  ## Select features: keep a feature if its |correlation| is among the n_top
  ## largest for at least one latent factor. abs() is applied on BOTH sides
  ## so strongly negative correlations qualify too; NA never qualifies.
  ftr_cor <- med_lf_list[[2]]
  lf_cor <- dplyr::select(ftr_cor, -feature, -omic_layer, -omic_num)
  is_top <- Reduce(
    `|`,
    lapply(lf_cor, function(x) abs(x) %in% tail(sort(abs(x)), n_top)),
    init = rep(FALSE, nrow(ftr_cor)))
  ftr_cor_sig_lf_top_ft <- ftr_cor[is_top, ]

  # Pivot longer: one row per feature x latent factor
  ftr_cor_sig_lf_top_ft_l <- ftr_cor_sig_lf_top_ft %>%
    pivot_longer(cols = all_of(med_lf_list[[1]]$lf_num),
                 names_to = "lf_num",
                 values_to = "Correlation")

  # JIVE individual components only describe their own omics layer: blank out
  # (NA) correlations of features with another layer's individual component.
  if (isTRUE(med_lf_list[[3]] == "Intermediate")) {
    ftr_cor_sig_lf_top_ft_l <- ftr_cor_sig_lf_top_ft_l %>%
      mutate(
        in_ind_omic = str_sub(lf_num, 1, 5) == str_sub(omic_layer, 1, 5),
        Correlation = ifelse(in_ind_omic | str_detect(lf_num, "Joint"),
                             Correlation, NA))
  }

  # Join with panel A data to reuse the latent factor ordering
  panel_b_dat_top_ft <- left_join(ftr_cor_sig_lf_top_ft_l,
                                  med_long %>%
                                    filter(name == "Alpha") %>%
                                    dplyr::select(lf_num, lf_ordered),
                                  by = "lf_num") %>%
    mutate(omic_num2 = case_when(str_detect(lf_num, "meth") ~ 1,
                                 str_detect(lf_num, "transc") ~ 2,
                                 str_detect(lf_num, "miR") ~ 3,
                                 str_detect(lf_num, "pro") ~ 4,
                                 str_detect(lf_num, "met") ~ 5,
                                 TRUE ~ 0))

  panel_b <- ggplot(data = panel_b_dat_top_ft,
                    aes(y = feature,
                        x = lf_ordered,
                        fill = Correlation)) +
    geom_tile(color = "white") +
    facet_grid(omic_layer ~ omic_num2, scales = "free", space = "free") +
    scale_fill_gradient2(low = "blue",
                         mid = "white",
                         high = "red",
                         midpoint = 0,
                         limits = c(-1, 1),
                         breaks = c(-1, 0, 1),
                         na.value = "grey50") +
    theme(
      axis.text.x = element_text(size = 8, angle = 90, hjust = 1, vjust = .5),
      axis.text.y = element_text(size = 8),
      strip.text = element_blank(),
      axis.title = element_blank(),
      axis.ticks.x = element_blank(),
      legend.position = "none",
      text = element_text(size = 8))

  # Combine figures; the NULL slots are spacers that carry the panel labels
  p <- cowplot::plot_grid(
    NULL, panel_a, NULL, panel_b,
    ncol = 1, align = "v", axis = "lr",
    rel_heights = c(.05, .6, .1, 1.75),
    labels = c("a)", "", "b) "))

  return(p)
}

6.3 Integrated/Quasi Mediation Analysis Code

6.3.1 Early Integration of omics datasets

Code
#' Plot Sankey diagram for LUCID in early integration
#'
#' Given a fitted LUCID model from early integration, builds a plotly Sankey
#' diagram. It has links from the exposure to the latent clusters, and from
#' the latent clusters to the top omics features and to the outcome.
#'
#' @param lucid_fit1 A fitted LUCID model object. Fields read: `K`,
#'   `var.names$Gnames`, `var.names$Znames`, `var.names$Ynames`, `pars$beta`,
#'   `pars$mu` and `pars$gamma$beta`.
#' @param text_size Font size of the text in the Sankey diagram.
#'
#' @return A plotly Sankey figure for LUCID in early integration.
#'
#' @import dplyr
#' @importFrom plotly plot_ly layout
#' @importFrom stringr str_detect
sankey_early_integration <- function(lucid_fit1, text_size = 15) {
  # Build link and node tables from a LUCID fit ----
  # For each omics layer (inferred from the feature name) only the `n_top`
  # features with the largest |mu| in Latent Cluster1 are kept.
  get_sankey_df <- function(x, n_top = 7,
                            covar_pattern = "cohort|age|fish|sex") {
    K <- x$K
    var.names <- x$var.names
    dimG <- length(var.names$Gnames)
    dimZ <- length(var.names$Znames)
    valueGtoX <- as.vector(t(x$pars$beta[, -1]))
    valueXtoZ <- as.vector(t(x$pars$mu))
    valueXtoY <- as.vector(x$pars$gamma$beta)[seq_len(K)]

    # GtoX: exposure/covariates -> latent clusters
    GtoX <- data.frame(
      source = rep(var.names$Gnames, K),
      target = paste0("Latent Cluster", rep(seq_len(K), each = dimG)),
      value = abs(valueGtoX),
      group = as.factor(valueGtoX > 0))

    # XtoZ: latent clusters -> omics features
    XtoZ <- data.frame(
      source = paste0("Latent Cluster", rep(seq_len(K), each = dimZ)),
      target = rep(var.names$Znames, K),
      value = abs(valueXtoZ),
      group = as.factor(valueXtoZ > 0))

    # Top n_top features of each omics layer, ranked by |mu| in cluster 1
    top_ftrs <- XtoZ %>%
      filter(source == "Latent Cluster1") %>%
      mutate(omics = case_when(grepl("cg", target) ~ "Methylation",
                               grepl("tc", target) ~ "Transcriptome",
                               grepl("miR", target) ~ "miRNA")) %>%
      group_by(omics) %>%
      arrange(desc(value)) %>%
      slice(seq_len(n_top)) %>%
      ungroup()
    XtoZ_sub <- XtoZ %>%
      filter(target %in% top_ftrs$target)

    # XtoY: latent clusters -> outcome
    XtoY <- data.frame(source = paste0("Latent Cluster", seq_len(K)),
                       target = rep(var.names$Ynames, K),
                       value = abs(valueXtoY),
                       group = as.factor(valueXtoY > 0))
    links <- rbind(GtoX, XtoZ_sub, XtoY)

    # Nodes in order of first appearance: exposure/covariates, latent
    # clusters, selected features, outcome. The feature count is the number
    # of distinct selected features, whatever the number of clusters K.
    nodes <- data.frame(
      name = unique(c(as.character(links$source),
                      as.character(links$target))),
      group = as.factor(c(rep("exposure", dimG),
                          rep("lc", K),
                          rep("biomarker", length(unique(XtoZ_sub$target))),
                          "outcome")))

    # Exclude covariates from the plot: drop links whose source matches
    # `covar_pattern`, and exactly those source nodes, so every remaining
    # link still points at an existing node.
    covar_nodes <- unique(links$source[grepl(covar_pattern, links$source)])
    links <- links %>% filter(!source %in% covar_nodes)
    nodes <- nodes %>% filter(!name %in% covar_nodes)

    # 0-based node indices, as used by plotly's sankey links
    links$IDsource <- match(links$source, nodes$name) - 1
    links$IDtarget <- match(links$target, nodes$name) - 1

    list(links = links, nodes = nodes)
  }

  # 1. Get sankey dataframes ----
  sankey_dat <- get_sankey_df(lucid_fit1)
  links <- sankey_dat[["links"]]
  nodes <- sankey_dat[["nodes"]]

  # Number of nodes in each column of the diagram (for vertical placement)
  n_exposure <- sum(nodes$group == "exposure")
  n_lc <- sum(nodes$group == "lc")
  n_biomarker <- sum(nodes$group == "biomarker")

  # Plot groups per node; the exposure "G1" is displayed as "Hg"
  nodes1 <- nodes %>%
    mutate(group = case_when(str_detect(name, "Cluster") ~ "lc",
                             str_detect(name, "cg") ~ "CpG",
                             str_detect(name, "outcome") ~ "outcome",
                             str_detect(name, "pro") ~ "Prot",
                             str_detect(name, "met") ~ "Met",
                             str_detect(name, "tc") ~ "TC",
                             str_detect(name, "miR") ~ "miRNA",
                             str_detect(name, "G1") ~ "exposure"),
           name = ifelse(name == "G1", "Hg", name))
  links1 <- links %>%
    mutate(source = ifelse(source == "G1", "Hg", source))

  # 6. Plotly Version ----

  ## 6.1 Set Node Color Scheme ----
  # NOTE(review): `sankey_colors` is not defined in this function; it is
  # looked up in the enclosing (presumably global) environment and must have
  # `domain`/`range` columns with "exposure", "layer1"-"layer3", "Outcome".
  color_pal_sankey <- matrix(
    c("exposure", sankey_colors$range[sankey_colors$domain == "exposure"],
      "lc",       "#b3d8ff",
      "CpG",      sankey_colors$range[sankey_colors$domain == "layer1"],
      "TC",       sankey_colors$range[sankey_colors$domain == "layer2"],
      "miRNA",    sankey_colors$range[sankey_colors$domain == "layer3"],
      "outcome",  sankey_colors$range[sankey_colors$domain == "Outcome"]),
    ncol = 2, byrow = TRUE) %>%
    as_tibble(.name_repair = "unique") %>%
    janitor::clean_names() %>%
    dplyr::rename(group = x1, color = x2)

  # Add colors and horizontal positions (0 = exposure, 1/3 = clusters,
  # 2/3 = features and outcome) to the nodes
  nodes_new_plotly <- nodes1 %>%
    left_join(color_pal_sankey, by = "group") %>%
    mutate(
      x = case_when(
        group == "exposure" ~ 0,
        str_detect(name, "Cluster") ~ 1/3,
        str_detect(name, "cg") |
          str_detect(name, "tc") |
          str_detect(name, "miR") |
          str_detect(name, "outcome") ~ 2/3))

  # Display names of nodes
  nodes_new_plotly1 <- nodes_new_plotly %>%
    dplyr::select(group, color, x, name) %>%
    mutate(name = case_when(
      name == "value" ~ "<b>Hg</b>",
      name == "Latent Cluster1" ~ "<b>Joint Omics\nProfile 0</b>",
      name == "Latent Cluster2" ~ "<b>Joint Omics\nProfile 1</b>",
      TRUE ~ name))

  ## 6.2 Get links for Plotly, set color ----
  # Dark shade = positive association, light shade = negative association
  links_new <- links1 %>%
    mutate(
      link_color = case_when(
        # Zero-weight links
        value == 0 ~ "#f3f6f4",
        # Exposure
        str_detect(source, "Hg") & group == TRUE ~ "red",
        # Outcome
        str_detect(target, "outcome") & group == TRUE ~ "red",
        # Transcriptome
        str_detect(target, "tc") & group == TRUE ~ "#bf9000",
        str_detect(target, "tc") & group == FALSE ~ "#ffd966",
        # Methylation
        str_detect(target, "cg") & group == TRUE ~ "#38761d",
        str_detect(target, "cg") & group == FALSE ~ "#b6d7a8",
        # miRNA
        str_detect(target, "miR") & group == TRUE ~ "#a64d79",
        str_detect(target, "miR") & group == FALSE ~ "#ead1dc",
        # Any other link
        group == FALSE ~ "#D3D3D3", # Negative association
        group == TRUE ~ "#706C6C")) # Positive association

  plotly_link <- list(
    source = links_new$IDsource,
    target = links_new$IDtarget,
    # tiny offset, presumably so zero-weight links are still drawn
    value = links_new$value + .00000000000000000000001,
    color = links_new$link_color)

  # Vertical node positions, in node order: exposure, latent clusters,
  # features, outcome. For one exposure, two clusters and 21 features this is
  # c(0.01, 0.1, 0.5, 0.05, 0.09, ..., 0.85, 0.98).
  biomarker_y <- if (n_biomarker <= 24) {
    seq(from = 0.05, by = 0.04, length.out = n_biomarker)
  } else {
    # Too many features to space 0.04 apart within [0, 1]; spread evenly
    seq(from = 0.05, to = 0.97, length.out = n_biomarker)
  }
  node_y <- c(seq(from = 0.01, by = 0.04, length.out = n_exposure),
              seq(from = 0.1, to = 0.5, length.out = n_lc),
              biomarker_y,
              0.98)

  plotly_node <- list(
    label = nodes_new_plotly1$name,
    color = nodes_new_plotly1$color,
    pad = 15,
    thickness = 20,
    line = list(color = "black", width = 0.5),
    x = nodes_new_plotly1$x,
    y = node_y)

  ## 6.3 Plot Figure ----
  fig <- plot_ly(
    type = "sankey",
    domain = list(
      x = c(0, 1),
      y = c(0, 1)),
    orientation = "h",
    node = plotly_node,
    link = plotly_link)

  fig <- fig %>% layout(font = list(size = text_size))
  return(fig)
}

6.3.2 Intermediate Integration of omics datasets

6.3.3 Late Integration of omics datasets

Code
# Plot Lucid In Serial Function ------
# Loads a helper script from the project's `functions/` folder; by its file
# name it presumably provides LUCID reorder/plot helpers (without Y) — verify.
# NOTE(review): relies on a `dir_proj` object (project root path) that must
# be defined elsewhere before this line runs.
source(fs::path(dir_proj, "functions", "lucid_reorder_plot_without_y.R"))

# Get sankey dataframe ----
# Build the link and node tables for a Sankey diagram of a fitted LUCID model.
#
# Arguments:
#   x               A fitted LUCID model. Fields read: `K`,
#                   `var.names$Gnames`, `var.names$Znames`,
#                   `var.names$Ynames`, `pars$beta` (first column skipped),
#                   `pars$mu` and `pars$gamma$beta`.
#   G_color, X_color, Z_color, Y_color, pos_link_color, neg_link_color,
#   fontsize        Currently unused: they only fed a `sankeyNetwork()` call
#                   that is no longer drawn here. Kept so existing calls work.
#   exclude_pattern Regular expression. Links whose SOURCE name matches are
#                   dropped, together with exactly those source nodes, so no
#                   remaining link points at a removed node. The default
#                   excludes the cohort/age/fish/sex covariates.
#
# Returns a list with
#   links  data frame: source, target, value (|coefficient|), group (factor,
#          "TRUE" for a positive coefficient), IDsource and IDtarget
#          (0-based row indices into `nodes`).
#   nodes  data frame: name, group (exposure / lc / biomarker / outcome),
#          ordered exposures, latent clusters, omics features, outcome.
get_sankey_df <- function(x,
                          G_color = "dimgray",
                          X_color = "#eb8c30",
                          Z_color = "#2fa4da",
                          Y_color = "#afa58e",
                          pos_link_color = "#67928b",
                          neg_link_color = "#d1e5eb",
                          fontsize = 7,
                          exclude_pattern = "cohort|age|fish|sex") {
  K <- x$K
  var.names <- x$var.names
  dimG <- length(var.names$Gnames)
  dimZ <- length(var.names$Znames)
  # t() + as.vector() lists all G (or Z) values for cluster 1, then cluster 2,
  # ..., matching the cluster-major ordering of the link tables below.
  valueGtoX <- as.vector(t(x$pars$beta[, -1]))
  valueXtoZ <- as.vector(t(x$pars$mu))
  valueXtoY <- as.vector(x$pars$gamma$beta)[seq_len(K)]

  # Exposure/covariates -> latent clusters
  GtoX <- data.frame(
    source = rep(var.names$Gnames, K),
    target = paste0("Latent Cluster", rep(seq_len(K), each = dimG)),
    value = abs(valueGtoX),
    group = as.factor(valueGtoX > 0))

  # Latent clusters -> omics features
  XtoZ <- data.frame(
    source = paste0("Latent Cluster", rep(seq_len(K), each = dimZ)),
    target = rep(var.names$Znames, K),
    value = abs(valueXtoZ),
    group = as.factor(valueXtoZ > 0))

  # Latent clusters -> outcome
  XtoY <- data.frame(
    source = paste0("Latent Cluster", seq_len(K)),
    target = rep(var.names$Ynames, K),
    value = abs(valueXtoY),
    group = as.factor(valueXtoY > 0))

  links <- rbind(GtoX, XtoZ, XtoY)

  # Node names in order of first appearance across sources then targets
  nodes <- data.frame(
    name = unique(c(as.character(links$source),
                    as.character(links$target))),
    group = as.factor(c(rep("exposure", dimG),
                        rep("lc", K),
                        rep("biomarker", dimZ),
                        "outcome")))

  # Exclude covariates: remove matching link sources and only those nodes
  is_excluded <- grepl(exclude_pattern, links$source)
  excluded_nodes <- unique(links$source[is_excluded])
  links <- links[!is_excluded, , drop = FALSE]
  nodes <- nodes[!nodes$name %in% excluded_nodes, , drop = FALSE]
  # Renumber rows 1..n after subsetting
  rownames(links) <- NULL
  rownames(nodes) <- NULL

  links$IDsource <- match(links$source, nodes$name) - 1
  links$IDtarget <- match(links$target, nodes$name) - 1

  list(links = links, nodes = nodes)
}


# lucid_fit1 <- fit1 
# lucid_fit2 <- fit2 
# lucid_fit3 <- fit3 

# sankey_in_serial Function ----
sankey_in_serial <- function(lucid_fit1, lucid_fit2, lucid_fit3, color_pal_sankey, text_size = 15) {
  
  # 1. Get sankey dataframes ----
  sankey_dat1 <- get_sankey_df(lucid_fit1)
  sankey_dat2 <- get_sankey_df(lucid_fit2)
  sankey_dat3 <- get_sankey_df(lucid_fit3)
  
  n_omics_1 <- length(lucid_fit1$var.names$Znames)
  n_omics_2 <- length(lucid_fit2$var.names$Znames)
  n_omics_3 <- length(lucid_fit3$var.names$Znames)
  
  # combine link data
  lnks1_methylation <- sankey_dat1[["links"]] %>% mutate(analysis = "1_methylation")
  lnks2_miRNA  <- sankey_dat2[["links"]] %>% mutate(analysis = "2_miRNA")
  lnks3_transcription    <- sankey_dat3[["links"]] %>% mutate(analysis = "3_transcript")
  links <- bind_rows(lnks1_methylation, lnks2_miRNA, lnks3_transcription)
  
  # combine node data
  nodes1_methylation <- sankey_dat1[["nodes"]] %>% mutate(analysis = "1_methylation")
  nodes2_miRNA  <- sankey_dat2[["nodes"]] %>% mutate(analysis = "2_miRNA")
  nodes3_transcription    <- sankey_dat3[["nodes"]] %>% mutate(analysis = "3_transcript")
  nodes <- bind_rows(nodes1_methylation, nodes2_miRNA, nodes3_transcription)
  
  
  # 2. Modify analysis 1 ----
  # For analysis 1, latent clusters need to be renamed to names from analysis 2:
  ## 2.1 Get new and original latent cluster names (from the next analysis) ----
  names_clusters_1 <- data.frame(
    name_og = c("Latent Cluster1", "Latent Cluster2"), 
    name_new = c("<b>Methylation\nProfile 0</b>", "<b>Methylation\nProfile 1</b>"))
  
  ## 2.2 Change link names ----
  # Change link names and 
  lnks1_methylation_new <- sankey_dat1[["links"]] %>%
    mutate(
      analysis = "1_methylation",
      source = case_when(
        source == names_clusters_1$name_og[1] ~ names_clusters_1$name_new[1],
        source == names_clusters_1$name_og[2] ~ names_clusters_1$name_new[2],
        TRUE ~ source),
      target = case_when(
        target == names_clusters_1$name_og[1] ~ names_clusters_1$name_new[1],
        target == names_clusters_1$name_og[2] ~ names_clusters_1$name_new[2],
        TRUE ~ target)) %>%
    filter(target != "outcome")
  
  ## 2.3 Change node names ----
  # first, change latent cluster names to analysis specific cluster names
  nodes1_methylation_new <- sankey_dat1[["nodes"]] %>%
    mutate(
      name = case_when(
        name == names_clusters_1$name_og[1] ~ names_clusters_1$name_new[1],
        name == names_clusters_1$name_og[2] ~ names_clusters_1$name_new[2],
        TRUE ~ name), 
      group = if_else(group == "biomarker", "CpG", as.character(group))) %>%
    filter(group != "outcome")
  
  
  # Visualize
  # sankeyNetwork(
  #   Links = lnks1_methylation_new,
  #   Nodes = nodes1_methylation_new,
  #   Source = "IDsource", Target = "IDtarget",
  #   Value = "value", NodeID = "name", LinkGroup = "group", NodeGroup = "group",
  #   sinksRight = FALSE)
  
  
  # 3. Modify analysis 2 ----
  # For analysis 2, latent clusters need to be renamed to names from analysis 3:
  ## 3.1 Get new and og latent cluster names ----
  names_clusters_2 <- data.frame(
    name_og = c("Latent Cluster1", "Latent Cluster2"), 
    name_new = c("<b>miRNA\nProfile 0</b>", "<b>miRNA\nProfile 1</b>"))
  
  ## 3.2 Change cluster names ----
  lnks2_miRNA_new <- sankey_dat2[["links"]] %>% 
    mutate(
      analysis = "2_miRNA", 
      source = case_when(
        source == names_clusters_2$name_og[1] ~ names_clusters_2$name_new[1], 
        source == names_clusters_2$name_og[2] ~ names_clusters_2$name_new[2], 
        TRUE ~ source), 
      target = case_when(
        target == names_clusters_2$name_og[1] ~ names_clusters_2$name_new[1], 
        target == names_clusters_2$name_og[2] ~ names_clusters_2$name_new[2], 
        TRUE ~ target)) %>%
    filter(target != "outcome")
  
  ## 3.3 Change node names ----
  nodes2_miRNA_new <- sankey_dat2[["nodes"]] %>% 
    mutate(
      name = case_when(
        name == names_clusters_2$name_og[1] ~ names_clusters_2$name_new[1], 
        name == names_clusters_2$name_og[2] ~ names_clusters_2$name_new[2], 
        TRUE ~ name), 
      group = case_when(group == "exposure" ~ "lc", 
                        group == "biomarker" ~ "miRNA",
                        TRUE ~ as.character(group))) %>%
    filter(name != "outcome")
  
  # Visualize
  # sankeyNetwork(
  #   Links = lnks2_transcript_new, 
  #   Nodes = nodes2_transcript_new,
  #   Source = "IDsource", Target = "IDtarget",
  #   Value = "value", NodeID = "name", 
  #   LinkGroup = "group", NodeGroup = "group",
  #   sinksRight = FALSE)
  ##
  
  # 4. Modify analysis 3 ----
  # For analysis 2, latent clusters need to be renamed to names from analysis 3:
  ## 4.1 Get new and og latent cluster names ----
  names_clusters_3 <- tibble(
    name_og = c("Latent Cluster1", "Latent Cluster2"),
    name_new = c("<b>Transcriptome\nProfile 0</b>", "<b>Transcriptome\nProfile 1</b>")) 
  
  
  ## 4.2 Change cluster names ----
  lnks3_transcript_new <- sankey_dat3[["links"]] %>% 
    mutate(
      analysis = "3_transcript", 
      source = case_when(
        source == names_clusters_3$name_og[1] ~ names_clusters_3$name_new[1], 
        source == names_clusters_3$name_og[2] ~ names_clusters_3$name_new[2], 
        TRUE ~ source), 
      target = case_when(
        target == names_clusters_3$name_og[1] ~ names_clusters_3$name_new[1], 
        target == names_clusters_3$name_og[2] ~ names_clusters_3$name_new[2], 
        TRUE ~ target))
  
  ## 4.3 Change node names ----
  nodes3_transcript_new <- sankey_dat3[["nodes"]] %>% 
    mutate(
      name = case_when(
        name == names_clusters_3$name_og[1] ~ names_clusters_3$name_new[1], 
        name == names_clusters_3$name_og[2] ~ names_clusters_3$name_new[2], 
        TRUE ~ name), 
      group = case_when(group == "exposure" ~ "lc", 
                        group == "biomarker" ~ "TC",
                        TRUE ~ as.character(group)))
  
  # Test/Visualize
  # sankeyNetwork(
  #   Links = lnks3_protein_new, 
  #   Nodes = nodes3_protein_new,
  #   Source = "IDsource", Target = "IDtarget",
  #   Value = "value", NodeID = "name", LinkGroup = "group", NodeGroup = "group",
  #   sinksRight = FALSE)
  
  
  
  # 5. Combine analysis 1-3 ----
  
  ## 5.1 Final Links ----
  links_all_1 <- bind_rows(lnks1_methylation_new, 
                           lnks2_miRNA_new,
                           lnks3_transcript_new) %>%
    dplyr::select(-IDsource, -IDtarget)
  
  
  ### 5.1.1 Arrange by magnitude ----
  omics_priority <- links_all_1 %>% 
    filter(str_detect(source, "Profile 0"), 
                    str_detect(target, "Profile 0", negate = TRUE), 
                    str_detect(target, "Profile 1", negate = TRUE), 
                    str_detect(target, "outcome", negate = TRUE)) %>%
    group_by(source) %>%
    arrange(desc(group), desc(value), .by_group = TRUE) %>%
    mutate(omics_order = row_number()) %>%
    ungroup() %>%
    dplyr::select(target, omics_order)
  
  
  
  links_all <- links_all_1 %>%
    left_join(omics_priority) %>%
    mutate(
      # arrange_me = if_else(is.na(omics_order), 
      #                           "dont_arrange", 
      #                           "arrange"), 
      row_num = row_number(), 
      # row_num_order_comb = if_else(is.na(omics_order), 
      #                              row_num, 
      #                              omics_order), 
      row_num_to_add = if_else(is.na(omics_order), 
                               as.numeric(row_num), 
                               NA_real_) %>%
        zoo::na.locf(),
      order = if_else(is.na(omics_order), 
                      row_num_to_add, 
                      row_num_to_add+omics_order)
    ) %>%
    arrange(order)
  
  
  ### 5.1.2 Get new source and target IDs ----
  # First, combine all layers, get unique identifier
  node_ids <- tibble(name = unique(c(unique(links_all$source), 
                                     unique(links_all$target)))) %>%
    mutate(ID = row_number()-1)
  
  # Then combine with original data 
  links_new <- links_all %>%
    left_join(node_ids, by = c("source" = "name")) %>%
    dplyr::rename(IDsource = ID) %>%
    left_join(node_ids, by = c("target" = "name")) %>%
    dplyr::rename(IDtarget = ID)
  
  
  ## 5.2 Final Nodes ----
  nodes_new <- node_ids %>%
    dplyr::select(name) %>%
    left_join(bind_rows(nodes1_methylation_new, 
                                 nodes2_miRNA_new,
                                 nodes3_transcript_new))
  # remove duplicates 
  nodes_new_nodup <- nodes_new[!base::duplicated(nodes_new),] %>%
    base::as.data.frame()
  
  
  # 6. Plotly Version ----
  
  # Add color scheme to nodes
  nodes_new_plotly <- nodes_new_nodup %>% 
    left_join(color_pal_sankey) %>%
    mutate(
      x = case_when(
        group == "exposure" ~ 0,
        str_detect(name, "Methylation") ~ 1/5, 
        str_detect(name, "miRNA") | 
          str_detect(group, "CpG") ~ 2/5, 
        str_detect(name, "Transcriptome") | 
          str_detect(group, "miRNA") ~ 3/5, 
        str_detect(group, "TC") ~  4/5, 
        str_detect(group, "outcome") ~ 4.5/5, 
      ))
  
  
  ## 6.2 Get links for Plotly, set color ----
  links_new <- links_new  %>%
    mutate(
      link_color = case_when(
        # Ref link color
        value == 0 ~     "#f3f6f4", 
        # Methylation 
        str_detect(target, "outcome") &  group == TRUE  ~  "red",
        
        str_detect(source, "Transcriptome") &  group == TRUE  ~  "#bf9000",
        str_detect(source, "Transcriptome") &  group == FALSE ~  "#ffd966",
        # Transcriptome
        str_detect(source, "Methylation") &  group == TRUE  ~  "#38761d",
        str_detect(source, "Methylation") &  group == FALSE ~  "#b6d7a8",
        # proteome
        str_detect(source, "miRNA") &  group == TRUE  ~  "#a64d79",
        str_detect(source, "miRNA") &  group == FALSE ~  "#ead1dc",
        
        links_new$group == FALSE ~ "#d9d2e9", # Negative association
        links_new$group == TRUE ~  "red")) # Positive association
  
  plotly_link <- list(
    source = links_new$IDsource,
    target = links_new$IDtarget,
    value = links_new$value+.00000000000000000000001, 
    color = links_new$link_color)  
  
  
  # Get list of nodes for Plotly
  plotly_node <- list(
    label = nodes_new_plotly$name, 
    color = nodes_new_plotly$color,
    pad = 15,
    thickness = 20,
    line = list(color = "black",width = 0.5), 
    x = nodes_new_plotly$x, 
    y = c(0.01, 
          0.1, 0.3, # Methylation clusters
          .45, .55, # Transcriptome clusters
          .80, .95, # Proteome clusters
          seq(from = .01, to = 1, by = 0.035)[1:n_omics_1], # Cpgs (10 total)
          seq(from = 0.35, to = 1, by = 0.025)[1:n_omics_2], # miRNA (8 total)
          seq(from = 0.75, to = 1, by = 0.03)[1:n_omics_3], # Transcript (10 total)
          .95
    ))
  
  
  ## 6.3 Plot Figure ----
  fig <- plot_ly(
    type = "sankey",
    domain = list(
      x =  c(0,1),
      y =  c(0,1)),
    orientation = "h",
    node = plotly_node,
    link = plotly_link)
  
  (fig <- fig %>% layout(
    # title = "Basic Sankey Diagram",
    font = list(
      size = text_size
    ))
  )
  
  return(fig)
}

6.3.4 Plot Omics Profile

Code
#' Plot omics profiles for each latent cluster estimated by LUCID
#'
#' Given a fitted LUCID model, draws a faceted bar chart of the estimated
#' cluster-specific omics means, with bars colored by omics layer.
#'
#' @param fit A fitted LUCID model object. For `"Early"` integration the
#'   function reads `fit$pars$mu`; for `"Intermediate"` integration it reads
#'   the first three elements of `fit$res_Mu_Sigma$Mu`.
#' @param integration_type Type of integration, either `"Early"` or
#'   `"Intermediate"`. Any other value raises an error.
#' @param meta_df Optional feature metadata with columns `ftr_name` and
#'   `omic_layer`; used only for `"Intermediate"` integration. When `NULL`
#'   (the default) an object named `meta_df` is looked up from the
#'   environment in which this function was defined (previous behaviour).
#' @param n_top Number of features per omics layer to plot for
#'   `"Intermediate"` integration, chosen by the largest absolute cluster
#'   mean. Defaults to 12.
#'
#' @return A ggplot object showing the omics profiles for each cluster.
#'
#' @import dplyr
#' @import ggplot2
#' @importFrom tidyr pivot_longer as_tibble
#' @importFrom stringr str_detect str_sub str_c
plot_omics_profiles <- function(fit, integration_type, meta_df = NULL,
                                n_top = 12) {
  valid_types <- c("Early", "Intermediate")
  if (!is.character(integration_type) || length(integration_type) != 1L ||
      !integration_type %in% valid_types) {
    stop("`integration_type` must be either \"Early\" or \"Intermediate\".",
         call. = FALSE)
  }

  # Styling shared by both integration types
  profile_theme <- theme(
    legend.position = "none",
    text = element_text(size = 10),
    axis.text.x = element_text(angle = 90, vjust = 1, hjust = 1),
    plot.margin = margin(10, 10, 10, 80),
    panel.background = element_rect(fill = "white"),
    strip.background = element_rect(fill = "white"),
    axis.line.x = element_line(color = "black"),
    axis.line.y = element_line(color = "black")
  )
  layer_fill <- scale_fill_manual(values = c("#2fa4da", "#A77E69", "#e7b6c1"))

  if (integration_type == "Early") {
    # NOTE(review): assumes fit$pars$mu has exactly two rows (one per latent
    # cluster) and one column per omics feature -- confirm against LUCID.
    M_mean <- as.data.frame(fit$pars$mu)
    M_mean$cluster <- as.factor(1:2)

    # Reshape to one row per (cluster, feature) and label the omics layer
    # from substrings of the feature name.
    M_mean_melt <- M_mean %>%
      pivot_longer(cols = -cluster, names_to = "variable",
                   values_to = "value") %>%
      mutate(
        cluster = paste0("Cluster ", cluster),
        color_label = case_when(
          str_detect(variable, "cg") ~ "1",
          str_detect(variable, "tc") ~ "2",
          TRUE ~ "3"
        )
      )

    fig <- ggplot(M_mean_melt,
                  aes(fill = color_label, y = value, x = variable)) +
      geom_bar(position = "dodge", stat = "identity") +
      ggtitle("Omics profiles for the two latent clusters") +
      facet_grid(rows = vars(cluster), scales = "free_y") +
      geom_hline(yintercept = 0) +
      xlab("") +
      profile_theme +
      layer_fill
  } else {
    if (is.null(meta_df)) {
      # Backward compatibility: resolve `meta_df` exactly as the former free
      # variable was resolved, i.e. from the function's enclosing environment.
      meta_df <- get("meta_df", envir = parent.env(environment()))
    }

    mu_layers <- fit$res_Mu_Sigma$Mu
    M_mean <- bind_rows(
      as_tibble(mu_layers[[1]], rownames = "variable"),
      as_tibble(mu_layers[[2]], rownames = "variable"),
      as_tibble(mu_layers[[3]], rownames = "variable")
    )

    # The miRNA layer's cluster columns are in reversed order, so swap them
    # for that layer to align "Low Risk"/"High Risk" across layers.
    M_mean1 <- M_mean %>%
      left_join(meta_df, by = c("variable" = "ftr_name")) %>%
      mutate(`Low Risk`  = if_else(omic_layer == "miRna", V2, V1),
             `High Risk` = if_else(omic_layer == "miRna", V1, V2)) %>%
      dplyr::select(-c("V1", "V2"))

    # One row per (feature, cluster), plus color and short layer labels
    M_mean2 <- M_mean1 %>%
      pivot_longer(cols = c(`Low Risk`, `High Risk`),
                   names_to = "cluster",
                   values_to = "value") %>%
      mutate(
        color_label = case_when(omic_layer == "methylome" ~ "1",
                                omic_layer == "transcriptome" ~ "2",
                                omic_layer == "miRna" ~ "3"),
        low_high = if_else(str_detect(cluster, "Low"), 0, 1),
        omic = if_else(omic_layer == "miRna",
                       "miR",
                       toupper(str_sub(omic_layer, end = 1))),
        omic_cluster = str_c(omic, low_high)
      )

    # Per omics layer, keep the n_top features whose largest absolute
    # cluster mean is greatest (sort must be on the per-row magnitude).
    top_features <- M_mean2 %>%
      group_by(variable) %>%
      filter(abs(value) == max(abs(value))) %>%
      ungroup() %>%
      arrange(desc(abs(value))) %>%
      # A feature whose two cluster means tie in magnitude would otherwise
      # occupy two of the n_top slots.
      distinct(variable, .keep_all = TRUE) %>%
      group_by(omic_layer) %>%
      slice_head(n = n_top) %>%
      ungroup()

    fig <- ggplot(M_mean2 %>% filter(variable %in% top_features$variable),
                  aes(fill = color_label, y = value, x = variable)) +
      geom_bar(position = "dodge", stat = "identity") +
      ggtitle("Omics profiles for 2 latent clusters - Lucid in Parallel") +
      facet_grid(rows = vars(cluster),
                 cols = vars(omic_layer), scales = "free_x", space = "free") +
      geom_hline(yintercept = 0) +
      xlab("") +
      profile_theme +
      layer_fill
  }

  fig
}